{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from deeptables.models import deeptable,deepnets,modelset\n",
    "from tensorflow import keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the MNIST train/test splits bundled with Keras.\n",
    "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
    "\n",
    "# Flatten every image into a single feature row and scale by 1/255.\n",
    "# The row count comes from the array itself and -1 lets NumPy infer the\n",
    "# flattened width, so the cell also works on a subset of the data instead\n",
    "# of depending on the hard-coded 60000/10000 sample counts.\n",
    "x_train = x_train.reshape(len(x_train), -1).astype('float32') / 255\n",
    "x_test = x_test.reshape(len(x_test), -1).astype('float32') / 255"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model configuration: one 'dnn_nets' component, AUC as the reported\n",
    "# metric, and an RMSprop optimizer built with its default arguments.\n",
    "conf = deeptable.ModelConfig(\n",
    "    optimizer=keras.optimizers.RMSprop(),\n",
    "    metrics=['AUC'],\n",
    "    nets=['dnn_nets'],\n",
    ")\n",
    "\n",
    "# Wrap the configuration in the DeepTable object used by the cells below.\n",
    "dt = deeptable.DeepTable(config=conf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Column index of X has been converted: Index(['x_0', 'x_1', 'x_2', 'x_3', 'x_4', 'x_5', 'x_6', 'x_7', 'x_8', 'x_9',\n",
      "       ...\n",
      "       'x_774', 'x_775', 'x_776', 'x_777', 'x_778', 'x_779', 'x_780', 'x_781',\n",
      "       'x_782', 'x_783'],\n",
      "      dtype='object', length=784)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10 class detected, inferred as a [multiclass classification] task\n",
      "Preparing features taken 1.638843059539795s\n",
      "Imputation taken 2.4904778003692627s\n",
      "Categorical encoding taken 1.5020370483398438e-05s\n",
      "Injected a callback [EarlyStopping]. monitor:val_AUC, patience:1, mode:max\n",
      ">>>>>>>>>>>>>>>>>>>>>> Model Desc <<<<<<<<<<<<<<<<<<<<<<< \n",
      "---------------------------------------------------------\n",
      "inputs:\n",
      "---------------------------------------------------------\n",
      "['all_categorical_vars: (0)', 'input_continuous_all: (717)']\n",
      "---------------------------------------------------------\n",
      "embeddings:\n",
      "---------------------------------------------------------\n",
      "input_dims: []\n",
      "output_dims: []\n",
      "dropout: 0.3\n",
      "---------------------------------------------------------\n",
      "dense: dropout: 0\n",
      "batch_normalization: False\n",
      "---------------------------------------------------------\n",
      "concat_embed_dense: shape: (None, 717)\n",
      "---------------------------------------------------------\n",
      "nets: ['dnn_nets']\n",
      "---------------------------------------------------------\n",
      "dnn: input_shape (None, 717), output_shape (None, 64)\n",
      "---------------------------------------------------------\n",
      "stacking_op: add\n",
      "---------------------------------------------------------\n",
      "output: activation: softmax, output_shape: (None, 10), use_bias: True\n",
      "loss: categorical_crossentropy\n",
      "optimizer: RMSprop\n",
      "---------------------------------------------------------\n",
      "\n",
      "Train on 48000 samples, validate on 12000 samples\n",
      "Epoch 1/10\n",
      "48000/48000 [==============================] - 22s 455us/sample - loss: 0.2644 - AUC: 0.9944 - val_loss: 0.1454 - val_AUC: 0.9977\n",
      "Epoch 2/10\n",
      "47744/48000 [============================>.] - ETA: 0s - loss: 0.1101 - AUC: 0.9983Restoring model weights from the end of the best epoch.\n",
      "48000/48000 [==============================] - 16s 334us/sample - loss: 0.1101 - AUC: 0.9983 - val_loss: 0.1178 - val_AUC: 0.9972\n",
      "Epoch 00002: early stopping\n",
      "Model has been saved to:dt_output/dt_20201019 132607_dnn_nets/dnn_nets.h5\n"
     ]
    }
   ],
   "source": [
    "# Train on the flattened training images for up to 10 epochs; the result\n",
    "# is unpacked as (model, history). In the saved run below, 12000 of the\n",
    "# 60000 samples were held out for validation and the injected EarlyStopping\n",
    "# callback (monitoring val_AUC) ended training after epoch 2.\n",
    "model, history = dt.fit(x_train, y_train, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the model on the held-out test split rather than on the training\n",
    "# data it was fit on, so the reported loss/AUC reflect generalization.\n",
    "dt.evaluate(x_test, y_test, batch_size=512, verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
